"""
Jonathan Colmer, Samir Rein and Savannah Quick 2020-2022

Python script for environmental citizen complaints using collapsed Gibbs sampling
for Latent Dirichlet Allocation, using 10 topics.

We do not create word clouds for the full data; this script only produces the numbers for Table E1 in the paper.
"""

import pandas as pd
import topicmodels
import topicmodels.preprocess
import os
import numpy as np
import matplotlib.pyplot as plt
from wordcloud import WordCloud, STOPWORDS 
from collections import Counter #used for wordcloud
import random #grayscale
from paths import target_path, full_sample_LDA_path #importing target_path from paths.py


def main():
    """Print a one-line banner naming the analysis this script performs."""
    banner = "Running: LDA Lemmas 10 Topics"
    print(banner)


# Emit the banner only when this file is executed directly.
if __name__ == "__main__":
    main()

# Switch the working directory to target_path so the relative data path used
# below resolves; if the directory is missing, only warn and keep going.
if not os.path.exists(target_path):
    print(f"Warning: {target_path} does not exist!")
else:
    os.chdir(target_path)



#####
# LDA estimation
#####


# Load the lemmatized complaints (path is relative to the current working
# directory) and turn missing values into empty strings before tokenizing.
dataframe = pd.read_csv(
    "Full_Sample/Processed_Data/Complaints_lemmatized.csv",
    encoding="utf-8",
)
dataframe = dataframe.replace(np.nan, '', regex=True)
docsobj2 = topicmodels.RawDocs(dataframe["IncidentDescriptionLemma"], "long")

# Vocabulary size and corpus length of the lemmatized text, before any
# further cleaning: flatten the per-document token lists into one list.
all_tokens = [token for doc in docsobj2.tokens for token in doc]
print("number of unique tokens = %d" % len(set(all_tokens)))
print("number of total tokens = %d" % len(all_tokens))

# Second cleaning pass on the lemmatized tokens, followed by stopword removal.
# NOTE(review): token_clean(1) presumably drops tokens of length <= 1 -
# confirm against the topicmodels package.
docsobj2.token_clean(1)
docsobj2.stopword_remove("tokens")


# Change directory so the LDA outputs (including the plots below) are saved
# in the correct folder.
os.chdir(full_sample_LDA_path)
docsobj2.term_rank("tokens")

# Plot the word-frequency measures used to decide where to cut off the
# dataset. Each ranking is drawn on its own figure: pyplot keeps adding lines
# to the current figure, so without a fresh figure the second image would
# also contain the tf-idf curve.
plt.figure()
plt.plot([entry[1] for entry in docsobj2.tfidf_ranking])
plt.savefig('tf_idf_ranking.png')
plt.close()

plt.figure()
plt.plot([entry[1] for entry in docsobj2.df_ranking])
plt.savefig('df_ranking.png')
plt.close()

# Use the score of the term at index 12500 of the tf-idf ranking as the
# removal cutoff.
# NOTE(review): this raises IndexError if the ranking holds 12500 or fewer
# terms; rank_remove presumably drops tokens scoring at or below the cutoff -
# confirm against the topicmodels package.
docsobj2.rank_remove("tfidf", "tokens", docsobj2.tfidf_ranking[12500][1])

# Report vocabulary size and corpus length again, after the tf-idf trim.
all_tokens = [token for doc in docsobj2.tokens for token in doc]
print("number of unique tokens = %d" % len(set(all_tokens)))
print("number of total tokens = %d" % len(all_tokens))
